{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# task4：卷积情感分析 \n",
    "\n",
    "在本节中，我们将利用卷积神经网络（CNN）进行情感分析，实现 [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882)中的模型。\n",
    "\n",
    "**注**：本次组队学习的目的不会全面介绍和解释CNN。要想学习更对相关知识可以请查看[此处](https://ujjwalkarn.me/2016/08/11/intuitive-explanation-convnets/)和[这里](https://cs231n.github.io/convolutional-networks/)。\n",
    "\n",
    "卷积神经网络在计算机视觉问题上表现出色，原因在于其能够从局部输入图像块中提取特征，并能将表示模块化，同时可以高效地利用数据。同样的，卷积神经网络也可以用于处理序列数据，时间可以被看作一个空间维度，就像二维图像的高度和宽度。\n",
    "\n",
    "那么为什么要在文本上使用卷积神经网络呢？与3x3 filter可以查看图像块的方式相同，1x2 filter 可以查看一段文本中的两个连续单词，即双字符。在上一个教程中，我们研究了FastText模型，该模型通过将bi-gram显式添加到文本末尾来使用bi-gram，在这个CNN模型中，我们将使用多个不同大小的filter，这些filter将查看文本中的bi-grams（a 1x2 filter）、tri-grams（a 1x3 filter）and/or n-grams（a 1x$n$ filter）。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.1 数据预处理\n",
    "\n",
    "与 task3 使用FastText模型的方法不同，本节不再需要刻意地创建bi-gram将它们附加到句子末尾。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
      "  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
      "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
      "  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
      "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
      "  warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torchtext.legacy import data\n",
    "from torchtext.legacy import datasets\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "SEED = 1234\n",
    "\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "TEXT = data.Field(tokenize = 'spacy', \n",
    "                  tokenizer_language = 'en_core_web_sm',\n",
    "                  batch_first = True)\n",
    "LABEL = data.LabelField(dtype = torch.float)\n",
    "\n",
    "train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
    "\n",
    "train_data, valid_data = train_data.split(random_state = random.seed(SEED))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "构建vocab，加载预训练词嵌入："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_VOCAB_SIZE = 25_000\n",
    "\n",
    "TEXT.build_vocab(train_data, \n",
    "                 max_size = MAX_VOCAB_SIZE, \n",
    "                 vectors = \"glove.6B.100d\", \n",
    "                 unk_init = torch.Tensor.normal_)\n",
    "\n",
    "LABEL.build_vocab(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建迭代器："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
      "  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
     ]
    }
   ],
   "source": [
    "BATCH_SIZE = 64\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data), \n",
    "    batch_size = BATCH_SIZE, \n",
    "    device = device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.2 构建模型\n",
    "\n",
    "开始构建模型！\n",
    "\n",
    "第一个主要问题是如何将CNN用于文本。图像一般是二维的，而文本是一维的。所以我们可以将一段文本中的每个单词沿着一个轴展开，向量中的元素沿着另一个维度展开。如考虑下面2个句子的嵌入句：\n",
    "\n",
    "![](assets/sentiment9.png)\n",
    "\n",
    "然后我们可以使用一个 **[n x emb_dim]** 的filter。这将完全覆盖 $n$ 个words，因为它们的宽度为`emb_dim` 尺寸。考虑下面的图像，我们的单词向量用绿色表示。这里我们有4个词和5维嵌入，创建了一个[4x5] \"image\" 张量。一次覆盖两个词（即bi-grams)）的filter 将是 **[2x5]** filter，以黄色显示，filter 的每个元素都有一个与之相关的 _weight_。此filter 的输出（以红色显示）将是一个实数，它是filter覆盖的所有元素的加权和。\n",
    "\n",
    "![](assets/sentiment12.png)\n",
    "\n",
    "然后，filter  \"down\" 移动图像（或穿过句子）以覆盖下一个bi-gram，并计算另一个输出（weighted sum）。\n",
    "\n",
    "![](assets/sentiment13.png)\n",
    "\n",
    "最后，filter 再次向下移动，并计算此 filter 的最终输出。\n",
    "\n",
    "![](assets/sentiment14.png)\n",
    "\n",
    "一般情况下，filter 的宽度等于\"image\" 的宽度，我们得到的输出是一个向量，其元素数等于图像的高度（或词的长度）减去 filter 的高度加上一。在当前例子中，$4-2+1=3$。\n",
    "\n",
    "上面的例子介绍了如何去计算一个filter的输出。我们的模型（以及几乎所有的CNN）有很多这样的 filter。其思想是，每个filter将学习不同的特征来提取。在上面的例子中，我们希望 **[2 x emb_dim]** filter中的每一个都会查找不同 bi-grams 的出现。\n",
    "\n",
    "在我们的模型中，我们还有不同尺寸的filter，高度为3、4和5，每个filter有100个。我们将寻找与分析电影评论情感相关的不同3-grams, 4-grams 和 5-grams 的情况。\n",
    "\n",
    "我们模型中的下一步是在卷积层的输出上使用pooling（具体是 max pooling）。这类似于FastText模型，不同的是在该模型中，我们计算其最大值，而非是FastText模型中每个词向量进行平均，下面的例子是从卷积层输出中获取得到向量的最大值（0.9）。\n",
    "\n",
    "![](assets/sentiment15.png)\n",
    "\n",
    "最大值是文本情感分析中“最重要”特征，对应于评论中的“最重要”n-gram。由于我们的模型有3种不同大小的100个filters，这意味着我们有300个模型认为重要的不同 n-grams。我们将它们连接成一个向量，并将它们通过线性层来预测最终情感。我们可以将这一线性层的权重视为\"weighting up the evidence\" 的权重，通过综合300个n-gram做出最终预测。\n",
    "\n",
    "\n",
    "### 实施细节\n",
    "\n",
    "1.我们借助 `nn.Conv2d`实现卷积层。`in_channels`参数是图像中进入卷积层的“通道”数。在实际图像中，通常有3个通道（红色、蓝色和绿色通道各有一个通道），但是当使用文本时，我们只有一个通道，即文本本身。`out_channels`是 filters 的数量，`kernel_size`是 filters 的大小。我们的每个“卷积核大小”都将是 **[n x emb_dim]** 其中 $n$ 是n-grams的大小。\n",
    "\n",
    "2.之后，我们通过卷积层和池层传递张量，在卷积层之后使用'ReLU'激活函数。池化层的另一个很好的特性是它们可以处理不同长度的句子。而卷积层的输出大小取决于输入的大小，不同的批次包含不同长度的句子。如果没有最大池层，线性层的输入将取决于输入语句的长度，为了避免这种情况，我们将所有句子修剪/填充到相同的长度，但是线性层来说，线性层的输入一直都是filter的总数。\n",
    "\n",
    "**注**：如果句子的长度小于实验设置的最大filter，那么必须将句子填充到最大filter的长度。在IMDb数据中不会存在这种情况，所以我们不必担心。\n",
    "\n",
    "3.最后，我们对合并之后的filter输出执行dropout操作，然后将它们通过线性层进行预测。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n",
    "                 dropout, pad_idx):\n",
    "        \n",
    "        super().__init__()\n",
    "        \n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n",
    "        \n",
    "        self.conv_0 = nn.Conv2d(in_channels = 1, \n",
    "                                out_channels = n_filters, \n",
    "                                kernel_size = (filter_sizes[0], embedding_dim))\n",
    "        \n",
    "        self.conv_1 = nn.Conv2d(in_channels = 1, \n",
    "                                out_channels = n_filters, \n",
    "                                kernel_size = (filter_sizes[1], embedding_dim))\n",
    "        \n",
    "        self.conv_2 = nn.Conv2d(in_channels = 1, \n",
    "                                out_channels = n_filters, \n",
    "                                kernel_size = (filter_sizes[2], embedding_dim))\n",
    "        \n",
    "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, text):\n",
    "                \n",
    "        #text = [batch size, sent len]\n",
    "        \n",
    "        embedded = self.embedding(text)\n",
    "                \n",
    "        #embedded = [batch size, sent len, emb dim]\n",
    "        \n",
    "        embedded = embedded.unsqueeze(1)\n",
    "        \n",
    "        #embedded = [batch size, 1, sent len, emb dim]\n",
    "        \n",
    "        conved_0 = F.relu(self.conv_0(embedded).squeeze(3))\n",
    "        conved_1 = F.relu(self.conv_1(embedded).squeeze(3))\n",
    "        conved_2 = F.relu(self.conv_2(embedded).squeeze(3))\n",
    "            \n",
    "        #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]\n",
    "        \n",
    "        pooled_0 = F.max_pool1d(conved_0, conved_0.shape[2]).squeeze(2)\n",
    "        pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)\n",
    "        pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)\n",
    "        \n",
    "        #pooled_n = [batch size, n_filters]\n",
    "        \n",
    "        cat = self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim = 1))\n",
    "\n",
    "        #cat = [batch size, n_filters * len(filter_sizes)]\n",
    "            \n",
    "        return self.fc(cat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "目前，`CNN` 模型使用了3个不同大小的filters，但我们实际上可以改进我们模型的代码，使其更通用，并且可以使用任意数量的filters。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n",
    "                 dropout, pad_idx):\n",
    "        \n",
    "        super().__init__()\n",
    "                \n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n",
    "        \n",
    "        self.convs = nn.ModuleList([\n",
    "                                    nn.Conv2d(in_channels = 1, \n",
    "                                              out_channels = n_filters, \n",
    "                                              kernel_size = (fs, embedding_dim)) \n",
    "                                    for fs in filter_sizes\n",
    "                                    ])\n",
    "        \n",
    "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, text):\n",
    "                \n",
    "        #text = [batch size, sent len]\n",
    "        \n",
    "        embedded = self.embedding(text)\n",
    "                \n",
    "        #embedded = [batch size, sent len, emb dim]\n",
    "        \n",
    "        embedded = embedded.unsqueeze(1)\n",
    "        \n",
    "        #embedded = [batch size, 1, sent len, emb dim]\n",
    "        \n",
    "        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n",
    "            \n",
    "        #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]\n",
    "                \n",
    "        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
    "        \n",
    "        #pooled_n = [batch size, n_filters]\n",
    "        \n",
    "        cat = self.dropout(torch.cat(pooled, dim = 1))\n",
    "\n",
    "        #cat = [batch size, n_filters * len(filter_sizes)]\n",
    "            \n",
    "        return self.fc(cat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "还可以使用一维卷积层实现上述模型，其中嵌入维度是 filter 的深度，句子中的token数是宽度。\n",
    "\n",
    "在本task中使用二维卷积模型进行测试，其中的一维模型的实现大家感兴趣的可以自行试一试。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN1d(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n",
    "                 dropout, pad_idx):\n",
    "        \n",
    "        super().__init__()\n",
    "        \n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n",
    "        \n",
    "        self.convs = nn.ModuleList([\n",
    "                                    nn.Conv1d(in_channels = embedding_dim, \n",
    "                                              out_channels = n_filters, \n",
    "                                              kernel_size = fs)\n",
    "                                    for fs in filter_sizes\n",
    "                                    ])\n",
    "        \n",
    "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, text):\n",
    "        \n",
    "        #text = [batch size, sent len]\n",
    "        \n",
    "        embedded = self.embedding(text)\n",
    "                \n",
    "        #embedded = [batch size, sent len, emb dim]\n",
    "        \n",
    "        embedded = embedded.permute(0, 2, 1)\n",
    "        \n",
    "        #embedded = [batch size, emb dim, sent len]\n",
    "        \n",
    "        conved = [F.relu(conv(embedded)) for conv in self.convs]\n",
    "            \n",
    "        #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]\n",
    "        \n",
    "        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
    "        \n",
    "        #pooled_n = [batch size, n_filters]\n",
    "        \n",
    "        cat = self.dropout(torch.cat(pooled, dim = 1))\n",
    "        \n",
    "        #cat = [batch size, n_filters * len(filter_sizes)]\n",
    "            \n",
    "        return self.fc(cat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建了`CNN` 类的一个实例。\n",
    "\n",
    "如果想运行一维卷积模型，我们可以将`CNN`改为`CNN1d`，注意两个模型给出的结果几乎相同。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_DIM = len(TEXT.vocab)\n",
    "EMBEDDING_DIM = 100\n",
    "N_FILTERS = 100\n",
    "FILTER_SIZES = [3,4,5]\n",
    "OUTPUT_DIM = 1\n",
    "DROPOUT = 0.5\n",
    "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
    "\n",
    "model = CNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "检查我们模型中的参数数量，我们可以看到它与FastText模型大致相同。\n",
    "\n",
    "“CNN”和“CNN1d”模型的参数数量完全相同。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model has 2,620,801 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(f'The model has {count_parameters(model):,} trainable parameters')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，加载预训练词嵌入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1117, -0.4966,  0.1631,  ...,  1.2647, -0.2753, -0.1325],\n",
       "        [-0.8555, -0.7208,  1.3755,  ...,  0.0825, -1.1314,  0.3997],\n",
       "        [-0.0382, -0.2449,  0.7281,  ..., -0.1459,  0.8278,  0.2706],\n",
       "        ...,\n",
       "        [ 0.6783,  0.0488,  0.5860,  ...,  0.2680, -0.0086,  0.5758],\n",
       "        [-0.6208, -0.0480, -0.1046,  ...,  0.3718,  0.1225,  0.1061],\n",
       "        [-0.6553, -0.6292,  0.9967,  ...,  0.2278, -0.1975,  0.0857]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pretrained_embeddings = TEXT.vocab.vectors\n",
    "\n",
    "model.embedding.weight.data.copy_(pretrained_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后，将未知标记和填充标记的初始权重归零。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
    "\n",
    "model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n",
    "model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.3 训练模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练和前面task一样，我们初始化优化器、损失函数（标准），并将模型和标准放置在GPU上。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "criterion = nn.BCEWithLogitsLoss()\n",
    "\n",
    "model = model.to(device)\n",
    "criterion = criterion.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "实现了计算精度的函数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def binary_accuracy(preds, y):\n",
    "    \"\"\"\n",
    "    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
    "    \"\"\"\n",
    "\n",
    "    #round predictions to the closest integer\n",
    "    rounded_preds = torch.round(torch.sigmoid(preds))\n",
    "    correct = (rounded_preds == y).float() #convert into float for division \n",
    "    acc = correct.sum() / len(correct)\n",
    "    return acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义了一个函数来训练我们的模型：\n",
    "\n",
    "**注意**：由于再次使用dropout，我们必须记住使用 `model.train()`以确保在训练时能够使用 dropout 。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    for batch in iterator:\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        predictions = model(batch.text).squeeze(1)\n",
    "        \n",
    "        loss = criterion(predictions, batch.label)\n",
    "        \n",
    "        acc = binary_accuracy(predictions, batch.label)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义了一个函数来测试我们的模型：\n",
    "\n",
    "**注意**：同样，由于使用的是dropout，我们必须记住使用`model.eval（）`来确保在评估时能够关闭 dropout。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "    \n",
    "        for batch in iterator:\n",
    "\n",
    "            predictions = model(batch.text).squeeze(1)\n",
    "            \n",
    "            loss = criterion(predictions, batch.label)\n",
    "            \n",
    "            acc = binary_accuracy(predictions, batch.label)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "            epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过函数得到一个epoch需要多长时间："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def epoch_time(start_time, end_time):\n",
    "    elapsed_time = end_time - start_time\n",
    "    elapsed_mins = int(elapsed_time / 60)\n",
    "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
    "    return elapsed_mins, elapsed_secs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后，训练我们的模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
      "  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 01 | Epoch Time: 0m 13s\n",
      "\tTrain Loss: 0.649 | Train Acc: 61.79%\n",
      "\t Val. Loss: 0.507 |  Val. Acc: 78.93%\n",
      "Epoch: 02 | Epoch Time: 0m 13s\n",
      "\tTrain Loss: 0.433 | Train Acc: 79.86%\n",
      "\t Val. Loss: 0.357 |  Val. Acc: 84.57%\n",
      "Epoch: 03 | Epoch Time: 0m 13s\n",
      "\tTrain Loss: 0.305 | Train Acc: 87.36%\n",
      "\t Val. Loss: 0.312 |  Val. Acc: 86.76%\n",
      "Epoch: 04 | Epoch Time: 0m 13s\n",
      "\tTrain Loss: 0.224 | Train Acc: 91.20%\n",
      "\t Val. Loss: 0.303 |  Val. Acc: 87.16%\n",
      "Epoch: 05 | Epoch Time: 0m 14s\n",
      "\tTrain Loss: 0.159 | Train Acc: 94.16%\n",
      "\t Val. Loss: 0.317 |  Val. Acc: 87.37%\n"
     ]
    }
   ],
   "source": [
    "N_EPOCHS = 5\n",
    "\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "\n",
    "    start_time = time.time()\n",
    "    \n",
    "    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
    "    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
    "    \n",
    "    end_time = time.time()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "    \n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), 'tut4-model.pt')\n",
    "    \n",
    "    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们得到的测试结果与前2个模型结果差不多！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 0.343 | Test Acc: 85.31%\n"
     ]
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('tut4-model.pt'))\n",
    "\n",
    "test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
    "\n",
    "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.4 模型验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import spacy\n",
    "nlp = spacy.load('en_core_web_sm')\n",
    "\n",
    "def predict_sentiment(model, sentence, min_len = 5):\n",
    "    model.eval()\n",
    "    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
    "    if len(tokenized) < min_len:\n",
    "        tokenized += ['<pad>'] * (min_len - len(tokenized))\n",
    "    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
    "    tensor = torch.LongTensor(indexed).to(device)\n",
    "    tensor = tensor.unsqueeze(0)\n",
    "    prediction = torch.sigmoid(model(tensor))\n",
    "    return prediction.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "负面评论的例子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.09913548082113266"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict_sentiment(model, \"This film is terrible\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "正面评论的例子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9769725799560547"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict_sentiment(model, \"This film is great\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "在下一节中，我们将学习多类型情感分析。\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
